"""
Phase 7: Dynamic Effects — Jordà Local Projections
====================================================
Estimate impulse responses at horizons h=0,...,5 for:
  y_{t+h} - y_{t-1} = α + β_h · x_t + controls_t + ε_{t+h}

Three specifications:
  A) Reduced form: x = log_predicted_demo_inflows (gravity instrument)
  B) Direct demographics: x = Z_1, Z_2, Z_3
  C) Reduced form (as in A) re-estimated on the OECD subsample

Outcomes: Δlog(K/L), Investment/GDP, ΔTFP, GDP growth, MPK

This reveals whether the puzzling negative contemporaneous investment sign
reverses at longer horizons (J-curve from demographic capital).
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input and output locations, all resolved relative to PROJECT_DIR.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_FIGURES = PROJECT_DIR / "output" / "figures"
# Ensure both output directories exist; this runs at import time.
OUT_TABLES.mkdir(parents=True, exist_ok=True)
OUT_FIGURES.mkdir(parents=True, exist_ok=True)

# Largest horizon estimated; projections run for h = 0..MAX_HORIZON inclusive.
MAX_HORIZON = 5
# Control regressors passed to every local projection alongside the key regressors.
CONTROLS = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'log_rel_opw', 'kaopen']


def stars(p):
    """Return the significance marker for p-value ``p``.

    '***' below 0.01, '**' below 0.05, '*' below 0.1, otherwise ''.
    A NaN p-value fails every comparison and therefore yields ''.
    """
    thresholds = ((0.01, '***'), (0.05, '**'), (0.1, '*'))
    for cutoff, marker in thresholds:
        if p < cutoff:
            return marker
    return ''


def build_forward_outcomes(df, y_var, max_h):
    """
    Add horizon-specific outcome columns ``{y_var}_h0 .. {y_var}_h{max_h}``.

    The input is copied and sorted by (iso3, year); the caller's frame is
    left untouched. Within each country:

    * growth outcomes (delta_log_kl, delta_log_tfp, rgdp_growth) at h > 0
      hold the sum of y over rows t..t+h, which is NaN unless all h+1
      values are present;
    * every other outcome, and any outcome at h = 0, holds y led by h rows.

    Leads are positional (row offsets within a country), so gaps in the
    year sequence are not detected here.
    """
    panel = df.sort_values(['iso3', 'year']).copy()

    # Growth rates are accumulated into a cumulative response; levels are led.
    cumulate = y_var in ('delta_log_kl', 'delta_log_tfp', 'rgdp_growth')

    def forward_sum(series, steps):
        # Trailing window sum ending at each row, pulled back by `steps`
        # rows so that row t carries the sum over t..t+steps.
        window = steps + 1
        trailing = series.rolling(window=window, min_periods=window).sum()
        return trailing.shift(-steps)

    for h in range(max_h + 1):
        by_country = panel.groupby('iso3')[y_var]
        target = f'{y_var}_h{h}'
        if cumulate and h > 0:
            panel[target] = by_country.transform(lambda s: forward_sum(s, h))
        else:
            panel[target] = by_country.shift(-h)

    return panel


def run_local_projection(df, y_var_h, x_vars, controls, label):
    """
    Estimate one local projection of ``y_var_h`` on ``x_vars + controls``.

    Requested columns that are absent from ``df`` are silently dropped, and
    rows with any missing value among the remaining columns are excluded.
    Returns None when fewer than 50 complete rows remain. Otherwise returns
    a dict with 'label', 'n_obs', 'r_squared' and, for every regressor
    actually used, '<name>_coef', '<name>_se' and '<name>_p'.
    """
    regressors = x_vars + controls
    wanted = [y_var_h, *regressors, 'iso3', 'year']
    sample = df[[col for col in wanted if col in df.columns]].dropna()

    # Too few complete observations to estimate anything meaningful.
    if len(sample) < 50:
        return None

    used = [name for name in regressors if name in sample.columns]

    model = PanelGLS()
    model.fit(
        sample[y_var_h].values,
        sample[used].values,
        sample['iso3'].values,
        sample['year'].values,
    )

    summary = {
        'label': label,
        'n_obs': model.n_obs,
        'r_squared': model.r_squared,
    }
    # Estimates are read by position, in the same order as the design columns.
    for idx, name in enumerate(used):
        summary.update({
            f'{name}_coef': model.beta[idx],
            f'{name}_se': model.se[idx],
            f'{name}_p': model.pvalues[idx],
        })

    return summary


def run_lp_sequence(df, y_var, x_vars, key_var, controls, spec_label):
    """
    Run local projections of ``y_var`` for h = 0, ..., MAX_HORIZON.

    Forward outcomes are built once, then one regression is estimated per
    horizon on ``x_vars + controls``; only the estimates for ``key_var``
    are kept.

    Returns a list with one dict per horizon, always carrying the same
    keys: horizon, coef, se, p, ci_lo, ci_hi, n_obs, r_squared. The
    interval is coef ± 1.96·se. A horizon whose regression was skipped
    (too few observations) has NaN estimates and n_obs = 0. If ``key_var``
    was not among the regressors actually used, its estimates are NaN.
    """
    df_lp = build_forward_outcomes(df, y_var, MAX_HORIZON)

    irf = []
    for h in range(MAX_HORIZON + 1):
        result = run_local_projection(
            df_lp, f'{y_var}_h{h}', x_vars, controls, f'{spec_label} h={h}'
        )

        if result is None:
            # Emit the full record schema even for skipped horizons so that
            # consumers can index every key without special-casing.
            irf.append({
                'horizon': h,
                'coef': np.nan,
                'se': np.nan,
                'p': np.nan,
                'ci_lo': np.nan,
                'ci_hi': np.nan,
                'n_obs': 0,
                'r_squared': np.nan,
            })
            continue

        coef = result.get(f'{key_var}_coef', np.nan)
        se = result.get(f'{key_var}_se', np.nan)
        p = result.get(f'{key_var}_p', np.nan)
        irf.append({
            'horizon': h,
            'coef': coef,
            'se': se,
            'p': p,
            'ci_lo': coef - 1.96 * se,
            'ci_hi': coef + 1.96 * se,
            'n_obs': result['n_obs'],
            'r_squared': result['r_squared'],
        })

    return irf


def make_ascii_irf(irf_data, title, width=60):
    """
    Render an impulse response as a text table with an inline ASCII plot.

    ``irf_data`` is a list of per-horizon dicts (horizon, coef, se, p,
    n_obs, and ci_lo/ci_hi for horizons with a non-NaN coef). Each plotted
    row shows the confidence band as dashes and the point estimate as
    '*' (p<0.05), 'o' (p<0.1) or '·'. The zero marker '|' is drawn only
    when zero lies outside the band. The horizontal axis spans the
    confidence bands and always includes zero.

    Returns the rendered text, or the title followed by '(no data)' when
    no horizon has an estimate.
    """
    # Axis range is driven by the confidence bands of estimable horizons.
    band_edges = [
        edge
        for pt in irf_data
        if not np.isnan(pt['coef'])
        for edge in (pt['ci_lo'], pt['ci_hi'])
    ]
    if not band_edges:
        return title + '\n  (no data)\n'

    # Force zero into the range so the reference line is always on the axis.
    vmin = min(min(band_edges), 0)
    vmax = max(max(band_edges), 0)
    span = vmax - vmin if vmax > vmin else 1

    def pos(val):
        return int((val - vmin) / span * (width - 1))

    zero_pos = pos(0)

    lines = [
        title,
        '=' * len(title),
        '',
        f'  h  {"coef":>10s} {"SE":>8s} {"p":>6s}  {"N":>5s}  Plot',
        f'  -  {"----":>10s} {"--":>8s} {"---":>6s}  {"---":>5s}  ' + '-' * width,
    ]

    for pt in irf_data:
        h = pt['horizon']
        if np.isnan(pt['coef']):
            lines.append(f'  {h}  {"n/a":>10s}')
            continue

        sig = stars(pt['p'])
        coef_str = f"{pt['coef']:.5f}{sig}"
        se_str = f"{pt['se']:.5f}"
        p_str = f"{pt['p']:.4f}"
        n_str = f"{pt['n_obs']}"

        c_pos = pos(pt['coef'])
        lo_pos = max(0, pos(pt['ci_lo']))
        hi_pos = min(width - 1, pos(pt['ci_hi']))

        # Dashes across the confidence band, blanks elsewhere.
        cells = ['-' if lo_pos <= i <= hi_pos else ' ' for i in range(width)]

        # Point estimate marker, graded by significance.
        if pt['p'] < 0.05:
            marker = '*'
        elif pt['p'] < 0.1:
            marker = 'o'
        else:
            marker = '·'
        cells[c_pos] = marker

        # The zero marker is shown only when it falls outside the band.
        if not (lo_pos <= zero_pos <= hi_pos):
            cells[zero_pos] = '|'

        bar = ''.join(cells)
        lines.append(f'  {h}  {coef_str:>10s} {se_str:>8s} {p_str:>6s}  {n_str:>5s}  {bar}')

    half = width // 2
    gap = ' ' * (width - half - 10)
    lines.append(f'  {"":3s} {"":>10s} {"":>8s} {"":>6s}  {"":>5s}  ' + '-' * width)
    lines.append(f'  {"":3s} {vmin:>{half}.4f}{gap}{vmax:.4f}')
    lines.append('  (* p<0.05, o p<0.1, · p>0.1, | = zero)')
    lines.append('')
    return '\n'.join(lines)


def _tag_irf(irf, y_var, y_label, spec, key_var):
    """Annotate every IRF point in place with the metadata used for export and table filtering."""
    for pt in irf:
        pt['outcome'] = y_var
        pt['outcome_label'] = y_label
        pt['spec'] = spec
        pt['key_var'] = key_var


def _write_horizon_table(f, heading, all_irfs, outcomes, spec):
    """Write one markdown section: a coefficient-by-horizon table for the points tagged ``spec``."""
    horizons = range(MAX_HORIZON + 1)
    f.write(f"## {heading}\n\n")
    f.write("| Outcome | " + " | ".join(f"h={h}" for h in horizons) + " |\n")
    f.write("|---------|" + "-----|" * len(horizons) + "\n")

    for y_var, y_label in outcomes.items():
        row_irfs = [pt for pt in all_irfs
                    if pt['outcome'] == y_var and pt['spec'] == spec]
        if not row_irfs:
            continue
        cells = []
        for h in horizons:
            pt = next((p for p in row_irfs if p['horizon'] == h), None)
            if pt and not np.isnan(pt['coef']):
                cells.append(f"{pt['coef']:.4f}{stars(pt['p'])}")
            else:
                # Horizon missing or not estimable: leave the cell blank.
                cells.append('')
        f.write(f"| {y_label} | {' | '.join(cells)} |\n")
    f.write("\n")


def main():
    """
    Run all local-projection specifications and write the results.

    Reads ``deepening_panel.csv`` from DATA, estimates impulse responses for
    each available outcome under three specifications (A: reduced form,
    B: direct demographics Z_1..Z_3, C: reduced form on the OECD subsample),
    prints an ASCII plot for each, and writes
    ``local_projections_results.csv`` and ``local_projections.md`` to
    OUT_TABLES.
    """
    print("=" * 70)
    print("PHASE 7: DYNAMIC EFFECTS — JORDÀ LOCAL PROJECTIONS")
    print("=" * 70)

    df = pd.read_csv(DATA / "deepening_panel.csv")
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries")

    outcomes = {
        'delta_log_kl': 'Capital Deepening (cumulative Δlog K/L)',
        'gross_fixed_investment_gdp': 'Investment/GDP (level at t+h)',
        'delta_log_tfp': 'TFP Growth (cumulative Δlog TFP)',
        'rgdp_growth': 'GDP Growth (cumulative)',
        'mpk_proxy': 'MPK Proxy (level at t+h)',
    }
    instrument = 'log_predicted_demo_inflows'

    all_irfs = []  # For CSV export
    all_plots = []  # For markdown

    # ── Specification A: Reduced Form (predicted demographic inflows) ──
    print("\n" + "=" * 70)
    print("SPECIFICATION A: REDUCED FORM (Predicted Demographic Inflows)")
    print("=" * 70)

    for y_var, y_label in outcomes.items():
        if y_var not in df.columns:
            print(f"\nSkipping {y_label}: variable missing")
            continue
        if instrument not in df.columns:
            print("No instrument variable available")
            break

        print(f"\n--- {y_label} ---")
        irf = run_lp_sequence(
            df, y_var,
            x_vars=[instrument],
            key_var=instrument,
            controls=CONTROLS,
            spec_label=f'RF: demo_inflows → {y_var}'
        )
        _tag_irf(irf, y_var, y_label, 'A: Reduced Form', instrument)
        all_irfs.extend(irf)

        plot = make_ascii_irf(irf, f'A: Predicted Demo Inflows → {y_label}')
        print(plot)
        all_plots.append(plot)

    # ── Specification B: Direct Demographics (Z₁, Z₂, Z₃) ──
    print("\n" + "=" * 70)
    print("SPECIFICATION B: DIRECT DEMOGRAPHICS (Z₁, Z₂, Z₃)")
    print("=" * 70)

    for y_var, y_label in outcomes.items():
        if y_var not in df.columns:
            continue

        print(f"\n--- {y_label} ---")
        # All three Z components enter every regression jointly; the sequence
        # is re-run once per component to extract that component's IRF.
        for z_var, z_label in [('Z_1', 'Z₁'), ('Z_2', 'Z₂'), ('Z_3', 'Z₃')]:
            irf = run_lp_sequence(
                df, y_var,
                x_vars=['Z_1', 'Z_2', 'Z_3'],
                key_var=z_var,
                controls=CONTROLS,
                spec_label=f'Direct: {z_var} → {y_var}'
            )
            _tag_irf(irf, y_var, y_label, f'B: Direct ({z_label})', z_var)
            all_irfs.extend(irf)

            plot = make_ascii_irf(irf, f'B: {z_label} → {y_label}')
            print(plot)
            all_plots.append(plot)

    # ── Specification C: OECD subsample (where instruments are strong) ──
    print("\n" + "=" * 70)
    print("SPECIFICATION C: OECD SUBSAMPLE — REDUCED FORM")
    print("=" * 70)

    OECD = {
        'AUS', 'AUT', 'BEL', 'CAN', 'CHE', 'CHL', 'COL', 'CRI', 'CZE', 'DEU',
        'DNK', 'ESP', 'EST', 'FIN', 'FRA', 'GBR', 'GRC', 'HUN', 'IRL', 'ISL',
        'ISR', 'ITA', 'JPN', 'KOR', 'LTU', 'LUX', 'LVA', 'MEX', 'NLD', 'NOR',
        'NZL', 'POL', 'PRT', 'SVK', 'SVN', 'SWE', 'TUR', 'USA',
    }
    df_oecd = df[df['iso3'].isin(OECD)].copy()
    print(f"OECD subsample: {len(df_oecd)} obs, {df_oecd['iso3'].nunique()} countries")

    for y_var, y_label in outcomes.items():
        if y_var not in df_oecd.columns:
            continue
        if instrument not in df_oecd.columns:
            break

        print(f"\n--- {y_label} ---")
        irf = run_lp_sequence(
            df_oecd, y_var,
            x_vars=[instrument],
            key_var=instrument,
            controls=CONTROLS,
            spec_label=f'OECD RF: demo_inflows → {y_var}'
        )
        _tag_irf(irf, y_var, y_label, 'C: OECD Reduced Form', instrument)
        all_irfs.extend(irf)

        plot = make_ascii_irf(irf, f'C (OECD): Predicted Demo Inflows → {y_label}')
        print(plot)
        all_plots.append(plot)

    # ── Save results ──
    print("\n--- Saving results ---")

    # CSV with all IRF data
    irfs_df = pd.DataFrame(all_irfs)
    irfs_df.to_csv(OUT_TABLES / "local_projections_results.csv", index=False)

    # Markdown summary. The content contains non-ASCII characters (Δ, →, ₁,
    # ·, —), so the encoding is pinned rather than left to the locale default.
    with open(OUT_TABLES / "local_projections.md", 'w', encoding='utf-8') as f:
        f.write("# Table 7: Jordà Local Projections — Dynamic Effects\n\n")
        f.write(f"Horizons h=0,...,{MAX_HORIZON}. ")
        f.write("Growth variables (Δlog K/L, ΔTFP, GDP growth) are cumulated. ")
        f.write("Level variables (Investment/GDP, MPK) are leads.\n\n")

        # Summary tables: key coefficient by horizon for each specification.
        _write_horizon_table(
            f, "A: Reduced Form (Predicted Demographic Inflows → Outcome)",
            all_irfs, outcomes, 'A: Reduced Form')
        # Direct demographics table (Z₁ only for space)
        _write_horizon_table(
            f, "B: Direct Demographics (Z₁ coefficient by horizon)",
            all_irfs, outcomes, 'B: Direct (Z₁)')
        _write_horizon_table(
            f, "C: OECD Subsample — Reduced Form",
            all_irfs, outcomes, 'C: OECD Reduced Form')

        # ASCII plots
        f.write("## Impulse Response Plots\n\n")
        f.write("```\n")
        for plot in all_plots:
            f.write(plot + '\n')
        f.write("```\n")

        f.write("\n*p<0.1, **p<0.05, ***p<0.01. Controls: fiscal_bal_gdp, nfa_gdp_lag, ")
        f.write("log_rel_opw, kaopen. PanelGLS with Cochrane-Orcutt AR(1) correction.*\n")

    print(f"\nSaved: {OUT_TABLES / 'local_projections_results.csv'}")
    print(f"Saved: {OUT_TABLES / 'local_projections.md'}")
    print("\nPhase 7 complete.")


if __name__ == '__main__':
    main()
